{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "otnqBdyaYBrf"
      },
      "source": [
        "\u003ca href=\"https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/elementwise.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WUJKQgtPlpyp"
      },
      "source": [
        "# Examples of automatic nonlinearity NNGP/NTK computation using `stax.Elementwise` and `stax.ElementwiseNumerical`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e76rfN10CKHn"
      },
      "source": [
        "For details, please see \"[Fast Neural Kernel Embeddings for General Activations](https://arxiv.org/abs/2209.04121)\"."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Sl2JyRE1hK-z"
      },
      "source": [
        "# Imports and setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 56,
          "status": "ok",
          "timestamp": 1660802702887,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "rmJ6T0nFlcwA",
        "outputId": "5a1a166c-22b7-4743-841b-2664fb5d748d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "/bin/sh: line 1: pip: command not found\n",
            "/bin/sh: line 1: pip: command not found\n",
            "/bin/sh: line 1: pip: command not found\n"
          ]
        }
      ],
      "source": [
        "!pip install -q --upgrade pip\n",
        "!pip install -q jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
        "!pip install -q git+https://www.github.com/google/neural-tangents"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2-Y93-C7lPOC"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "from jax import numpy as jnp, random\n",
        "from neural_tangents import stax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8o90BE__iJVS"
      },
      "outputs": [],
      "source": [
        "key1, key2, key_init = random.split(random.PRNGKey(1), 3)\n",
        "x1 = random.normal(key1, (3, 2))\n",
        "x2 = random.normal(key2, (4, 2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z0QxD2fTmp0x"
      },
      "source": [
        "# 1. Using `Elementwise` to automatically derive the NTK in closed form from the NNGP."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3UDv9K_Nm5cs"
      },
      "source": [
        "`stax.Elementwise` derives under the hood the NTK function from the NNGP function using autodiff."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 3231,
          "status": "ok",
          "timestamp": 1660802716905,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "jRVet8A0nOHO",
        "outputId": "f71af7c2-4082-4c78-b255-54536929ed99"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0.0\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/export/hda3/borglet/remote_hdd_fs_dirs/0.colab_kernel_brain_frameworks_gpu_romann.kernel.romann.2612478628965.14b334fb3717c109/mount/server/ml_notebook.runfiles/google3/third_party/py/neural_tangents/_src/stax/elementwise.py:757: UserWarning: Using JAX autodiff to compute the `fn` derivative for NTK. Beware of https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.\n",
            "  warnings.warn(\n"
          ]
        }
      ],
      "source": [
        "# Hand-derived NTK expression for the sine nonlinearity.\n",
        "_, _, kernel_fn_manual = stax.serial(stax.Dense(1),\n",
        "                                     stax.Sin())\n",
        "\n",
        "# NNGP function for the sine nonlinearity:\n",
        "def nngp_fn(cov12, var1, var2):\n",
        "  sum_ = (var1 + var2)\n",
        "  s1 = jnp.exp((-0.5 * sum_ + cov12))\n",
        "  s2 = jnp.exp((-0.5 * sum_ - cov12))\n",
        "  return (s1 - s2) / 2\n",
        "\n",
        "# Let the `Elementwise` derive the NTK function in closed form automatically.\n",
        "_, _, kernel_fn = stax.serial(stax.Dense(1),\n",
        "                              stax.Elementwise(nngp_fn=nngp_fn))\n",
        "\n",
        "\n",
        "k_auto = kernel_fn(x1, x2, 'ntk')\n",
        "k_manual = kernel_fn_manual(x1, x2, 'ntk')\n",
        "\n",
        "# The two kernels match!\n",
        "print(jnp.max(jnp.abs(k_manual - k_auto)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AEyUoK_Un8y3"
      },
      "source": [
        "# 2. Using `ElementwiseNumerical` to approximate kernels given only the nonlinearity."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vBL3s42Dn8y4"
      },
      "source": [
        "`stax.ElementwiseNumerical` approximates the NNGP and NTK using Gaussian quadrature and autodiff."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 6559,
          "status": "ok",
          "timestamp": 1660802723600,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "symjuXs9ogsg",
        "outputId": "91770432-0a31-4761-c218-c5055a183f1c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "3.825128e-05\n",
            "8.523464e-05\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/export/hda3/borglet/remote_hdd_fs_dirs/0.colab_kernel_brain_frameworks_gpu_romann.kernel.romann.2612478628965.14b334fb3717c109/mount/server/ml_notebook.runfiles/google3/third_party/py/neural_tangents/_src/stax/elementwise.py:809: UserWarning: Numerical Activation Layer with fn=\u003cfunction gelu at 0x7fbf94fef940\u003e, deg=25 used!Note that numerical error is controlled by `deg` and for a giventolerance level, required `deg` will highly be dependent on the choiceof `fn`.\n",
            "  warnings.warn(\n",
            "/export/hda3/borglet/remote_hdd_fs_dirs/0.colab_kernel_brain_frameworks_gpu_romann.kernel.romann.2612478628965.14b334fb3717c109/mount/server/ml_notebook.runfiles/google3/third_party/py/neural_tangents/_src/stax/elementwise.py:819: UserWarning: Using JAX autodiff to compute the `fn` derivative for NTK. Beware of https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.\n",
            "  warnings.warn(\n"
          ]
        }
      ],
      "source": [
        "# A nonlinearity with a known closed-form expression (GeLU).\n",
        "_, _, kernel_fn_closed_form = stax.serial(\n",
        "  stax.Dense(1),\n",
        "  stax.Gelu(),  # Contains the closed-form GeLU NNGP/NTK expression.\n",
        "  stax.Dense(1)\n",
        ")\n",
        "kernel_closed_form = kernel_fn_closed_form(x1, x2)\n",
        "\n",
        "# Construct the layer from only the elementwise forward-pass GeLU.\n",
        "_, _, kernel_fn_numerical = stax.serial(\n",
        "  stax.Dense(1),\n",
        "  stax.ElementwiseNumerical(jax.nn.gelu, deg=25),  # quadrature and autodiff.\n",
        "  stax.Dense(1)\n",
        ")\n",
        "kernel_numerical = kernel_fn_numerical(x1, x2)\n",
        "\n",
        "# The two kernels are close!\n",
        "print(jnp.max(jnp.abs(kernel_closed_form.nngp - kernel_numerical.nngp)))\n",
        "print(jnp.max(jnp.abs(kernel_closed_form.ntk - kernel_numerical.ntk)))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "elementwise.ipynb",
      "provenance": [
        {
          "file_id": "1tu1K3EDnK5LKkTMK-SWFuJb-vJ3ZZvbl",
          "timestamp": 1660801805990
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
